{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as Data\n",
    "torch.manual_seed(8) # for reproduce\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import gc\n",
    "import sys\n",
    "sys.setrecursionlimit(50000)\n",
    "import pickle\n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "# from tensorboardX import SummaryWriter\n",
    "torch.nn.Module.dump_patches = True\n",
    "import copy\n",
    "import pandas as pd\n",
    "#then import my own modules\n",
    "from AttentiveFP import Fingerprint, Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import matthews_corrcoef\n",
    "from sklearn.metrics import recall_score\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import auc\n",
    "from sklearn.metrics import f1_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from rdkit.Chem import rdMolDescriptors, MolSurf\n",
    "# from rdkit.Chem.Draw import SimilarityMaps\n",
    "from rdkit import Chem\n",
    "# from rdkit.Chem import AllChem\n",
    "from rdkit.Chem import QED\n",
    "%matplotlib inline\n",
    "from numpy.polynomial.polynomial import polyfit\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib\n",
    "from IPython.display import SVG, display\n",
    "import seaborn as sns; sns.set(color_codes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of all smiles:  41127\n",
      "number of successfully processed smiles:  41127\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU8AAAC/CAYAAAB+KF5fAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGFdJREFUeJzt3X9M03f+B/BnCwjyw5YzHX4PuKkrbVZRcONnNGNDce4CGnbLeSy5k807PXPMWwZungaiF3NngIjLiGeQmG06b3dZNitwAWXc8G4gZ/RwOhjlhzOEC1CEgghWoP3+YfiMWgrtxxZKfT4SI7zf7777fqXtk08/n08/lZjNZjOIiMgh0vleABHRQsTwJCISgeFJRCQCw5OISASGJxGRCAxPIiIRGJ5ERCIwPImIRGB4EhGJwPAkIhKB4UlEJALDk4hIBIYnEZEI3vO9AHc1MHAPJpPtC04tXRqIO3eG53BFrudpNbEe9+cONUmlEgQHBzh8O4anDSaTecbwnBzjaTytJtbj/hZqTXzbTkQkAsOTiEgEhicRkQgMTyIiEXjAyA2MmwDj2LjNfl8fb3jzzxyRW2F4ugHj2DiuNPfY7I99NgTevnyoiNwJt2eIiERgeBIRicDwJCISgeFJRCQCw5OISASGJxGRCAxPIiIR7ArP7u5uHD58GBkZGVi7di3UajUaGhqsxiUnJ0OtVlv9KywstBrb19eH9957D/Hx8YiOjsbrr7+Oa9euTXv/ZWVl2LJlC1avXo0XXngBhYWFMBqNjzUnEdHjsOvM69u3b6OiogIajQYJCQmoqamxOTY2NhY5OTkWbSEhIRa/G41GZGZmYmRkBLm5uZDL5fjoo4+QmZmJTz/9FBqNRhir1Wrx7rvvIiMjA/v370d7ezsKCwvR1dWFoqIiUXMSET0uu8IzNjYW9fX1AIDq6uoZw3PJkiWIjo6ecb7PPvsMra2t+Pzzz7Fq1SoAQFxcHF555RUcPXoUpaWlAICJiQkUFBQgOTkZBw8eBAAkJCTAx8cHubm5yMzMRFRUlENzEhE5g11v26VS5+4ara6uhkqlEkIOABYtWoTU1FTU1dVhePjhlaUbGxuh1+uRnp5ucfu0tDT4+PigqqrK4TmJiJzB6QeMLl++jLVr1yIyMhJpaWk4e/YszGbLK0W3trZCpVJZ3VatVmNiYgIdHR3COACIiIiwGLd48WKEh4cL/Y7MSUTkDE692sSLL76IyMhIhIeHw2Aw4Pz58zh06BC+//577N+/XxhnMBggk8msbj/ZNjAwIIyb2v7o2Ml+R+a019KlgbOOUSiCHJrTFnP/CIIC/Wz2+/v7QvEjf6fc12ycVZO7YD3ub6HW5NTwzMvLs/g9JSUF2dnZOH36NLZv347Q0FChTyKR2Jzn0T5bY+0dN1vfdO7cGZ7xu1UUiiDo9XcdmtOWEeM47g7ft90/YoR+YsIp9zUTZ9bkDliP+3OHmqRSiV0bS1a3c8FaLKSnp8NkMuGbb74R2uRyucVW46TBwUGhf+r/tsZO3dK0d04iImdweXiaTKaHdzTloJNSqYROp7Ma29LSAi8vL6xcuVIYB8Bi3yYAjI6OorOz02JfqL1zEhE5g8vDU6vVQiqVYvXq1UJbSkoKdDodmpubhbYHDx6goqICiYmJCAx8uAkdHR0NhUIBrVZrMWd5eTnGxsawadMmh+ckInIGu/d5VlZWAgBu3LgBALhy5QoGBgawePFiJCUloby8HF9++SWSkpKwbNkyDA4O4vz586iursaOHTvw4x//WJjrtddewyeffIKsrCxkZ2dDJpPh448/Rm9vL44dO/bD4ry9kZ2djX379uGPf/wjXn75ZeEk+ZdfftnifFJ75yQicgaJ+dHziGxQq9XTtoeGhqKmpgaNjY04duwY2traYDAY4OPjA7VajW3btlmdpwkAer0e+fn5qK2thdFohEajQXZ2NmJi
YqzGarValJaW4tatWwgODkZaWhr27NkDPz8/0XPOZi4PGN0zzv41HAFz8DUc7rDz3plYj/tzh5rEHjCyOzyfNAzPhY/1uD93qMltj7YTEXkihicRkQgMTyIiERieREQiMDyJiERgeBIRicDwJCISgeFJRCQCw5OISASGJxGRCAxPIiIRGJ5ERCIwPImIRGB4EhGJwPAkIhKB4UlEJALDk4hIBIYnEZEIDE8iIhEYnkREIjA8iYhEYHgSEYnA8CQiEoHhSUQkAsOTiEgEhicRkQgMTyIiERieREQi2BWe3d3dOHz4MDIyMrB27Vqo1Wo0NDRMO7asrAxbtmzB6tWr8cILL6CwsBBGo9FqXF9fH9577z3Ex8cjOjoar7/+Oq5duzZncxIRPQ67wvP27duoqKiAv78/EhISbI7TarXIycnBc889h5MnT2LXrl345JNPsG/fPotxRqMRmZmZuHLlCnJzc1FcXIyAgABkZmaiqanJ5XMSET0ub3sGxcbGor6+HgBQXV2NmpoaqzETExMoKChAcnIyDh48CABISEiAj48PcnNzkZmZiaioKADAZ599htbWVnz++edYtWoVACAuLg6vvPIKjh49itLSUpfNSUTkDHZteUqlsw9rbGyEXq9Henq6RXtaWhp8fHxQVVUltFVXV0OlUgkhBwCLFi1Camoq6urqMDw87LI5iYicwWkHjFpbWwEAERERFu2LFy9GeHi40D85VqVSWc2hVqsxMTGBjo4Ol81JROQMTgtPg8EAAJDJZFZ9MplM6J8ca2scAAwMDLhsTiIiZ7Brn6cjJBKJXe22xjky9nHmnM3SpYGzjlEoghya0xZz/wiCAv1s9vv7+0LxI3+n3NdsnFWTu2A97m+h1uS08JTL5QAebgEGBwdb9A0ODiIsLMxi7NStxqnjps7lijntdefOMEwms81+hSIIev1dh+a0ZcQ4jrvD9233jxihn5hwyn3NxJk1uQPW4/7coSapVGLXxpLV7Zy1AKVSCQAW+yEBYHR0FJ2dnRb7LZVKJXQ6ndUcLS0t8PLywsqVK102JxGRMzgtPKOjo6FQKKDVai3ay8vLMTY2hk2bNgltKSkp0Ol0aG5uFtoePHiAiooKJCYmIjAw0GVzEhE5g9fByRMoZ1FZWYm2tjZcv34d165dQ1hYGPr7+9HV1YXly5dDKpUiODgYJSUlGBgYgJ+fHy5duoT8/HwkJyfjjTfeEOZSq9W4cOECysrKoFAo0NvbiyNHjqClpQWFhYV46qmnAMAlc9prdPQBzLbftSMgwBcjIw8cmtOWsQkT/td3z2Z/qCIQi7xd/0laZ9bkDliP+3OHmiQSCfz9Fzl+O7N5poj4gVqtnrY9NDTU4qR5rVaL0tJS3Lp1C8HBwUhLS8OePXvg52d5QESv1yM/Px+1tbUwGo3QaDTIzs5GTEyM1X24Ys7ZzOU+z3vGcVxp7rHZH7dqGcwzrMXXxxvOyFZ32P/kTKzH/blDTWL3edodnk8adwrPKJUC13V6m/2xz4YgwPfxj/25wxPZmViP+3OHmub9gBER0ZOE4UlEJALDk4hIBIYnEZEIDE8iIhEYnkREIjA8iYhEYHgSEYnA8CQiEsHp1/Mka+MmwDg2brN/hg8yEZGbYnjOAePY7B+/JKKFhW/biYhEYHgSEYnA8CQiEoHhSUQkAsOTiEgEhicRkQgMTyIiERieREQiMDyJiERgeBIRicDwJCISgeFJRCQCw5OISASGJxGRCAxPIiIRGJ5ERCIwPImIRGB4EhGJ4NTwbGhogFqtnvZfe3u7xdivv/4aP//5z7FmzRokJiYiLy8PQ0NDVnPeu3cPhw8fxvr167FmzRq8+uqr+PLLL6e9f3vnJCJ6XC75DqOcnBzExsZatIWFhQk/NzQ0YOfOndiwYQPefvtt9Pb2orCwEDqdDmfPnoVU+kOmZ2VloampCTk5OQgLC8MXX3yBrKwsnDhxAklJSaLmJCJ6XC4JzxUrViA6Otpm
f0FBASIiInDs2DEh1BQKBd58801UVlbipz/9KQCgtrYWdXV1KC4uRkpKCgAgISEBnZ2dOHLkiEV42jsnEZEzzPnmWE9PD27cuIGtW7dabA2uW7cOISEhqKqqEtouXryIoKAgbNiwQWiTSCRIT09HR0cH2traHJ6TiMgZXBKeeXl50Gg0eP7557Fr1y7cvHlT6NPpdACAiIgIq9upVCq0trYKv7e2tkKpVFq95Var1RZzOTInEZEzOPVte1BQELZv3464uDjI5XK0t7ejpKQEGRkZOHPmDKKiomAwGAAAMpnM6vYymQxNTU3C7waDAcuXL5923GT/1P/tmZOIyBmcGp4ajQYajUb4PSYmBsnJyUhNTUVRURE+/PBDoU8ikUw7x6PttsY5MnamOWxZujRw1jEKRZBdc5n7RxAU6Gez38fH+7H6/f19ofiRv11rmY29NS0UrMf9LdSaXHLAaCqFQoH169ejpqYGACCXywH8sLU41eDgoMXWo1wutzkO+GFL05E57XXnzjBMJrPNfoUiCHr9XbvmGjGO4+7wfZv9Y2OP1z8yYoR+YsKutczEkZoWAtbj/tyhJqlUYtfGktXtXLAWKyaTSfh5cr/kdPshdTqdxX5LpVKJ9vZ2i9tPjgMe7s90dE4iImdweXjq9XrU1dUJpy4tW7YMkZGRKCsrswjF+vp69PT0YNOmTUJbSkoKhoaGhK3WSefOncOKFSugVCodnpOIyBmc+rY9Ozsb4eHhWLVqFZYsWYKOjg6cPHkS9+/fxzvvvCOMy8nJwY4dO/DOO+9g27Zt6OnpQWFhIaKiorB582ZhXFJSEuLj43HgwAEYDAaEhYXh3LlzuHr1Ko4fP25x3/bOSUTkDE4NT7VajYqKCpw5cwajo6OQy+WIi4vD7t27hbfYAJCYmIgTJ07ggw8+wM6dOxEQEICNGzdi79698PLyEsZJJBIcP34cR48eRVFREYaGhqBUKlFcXIzk5GSL+7Z3TiIiZ5CYzWbbR0WeYM48YHTPOI4rzT02+6NUClzX6UX3xz4bggDfx/876A47752J9bg/d6jJrQ8YERF5GoYnEZEIDE8iIhEYnkREIrj8E0bkehKpBPeM4zOO8fXxhjf/VBI5DcPTAxjHJmY8Gg88PCLv7YQj8kT0ELdFiIhEYHgSEYnA8CQiEoHhSUQkAsOTiEgEhicRkQgMTyIiERieREQiMDyJiERgeBIRicDwJCISgeFJRCQCw5OISASGJxGRCAxPIiIRGJ5ERCIwPImIROClxZ8Qs31Vh68PnwpEjuAr5gkx21d1xD4bMoerIVr4+LadiEgEhicRkQgMTyIiERieREQieNQBo3v37qGoqAiVlZUYGhqCUqnE7373O2zYsGG+l+b2JFIJevtHMGLjiLyvjze8+aeWSOBR4ZmVlYWmpibk5OQgLCwMX3zxBbKysnDixAkkJSXN9/LcmnFsAs23e3F3+P60/bHPhsDb16OeLkSPxWNeDbW1tairq0NxcTFSUlIAAAkJCejs7MSRI0dcGp7jJsA4ZvscSpPZZXc9Z+w5T5RbpvQk8ZjwvHjxIoKCgizeokskEqSnpyM3NxdtbW1QKpUuuW/j2DiuNPfY7I9SKVxyv3PJnvNEuWVKTxKPeba3trZCqVRCKrXc/FGr1QAAnU7nUHhKpRK7x3h7SeHv52Nz3Hz32zvHYl9vTIxPP2bW2/t4wThumnkN3l4YH5+w2b/I2wteTt56tedxXEg8rR5g/msSe/8eE54GgwHLly+3apfJZEK/I4KDA2Yds3RpoPBz2P/JZhy7Mix4Xvvn6j7czdTHyBN4Wj3Awq3Jo/ZSSSS2/4LM1EdE5CiPCU+5XD7t1uXg4CCAH7ZAiYicwWPCU6lUor29HSaT5X43nU4HAFCpVPOxLCLyUB4TnikpKRgaGkJNTY1F+7lz57BixQqXHWknoieTxxwwSkpKQnx8PA4cOACDwYCwsDCcO3cOV69exfHjx+d7eUTkYSRm
s9kDTuF+aHh4GEePHkVVVZXFxzM3btw430sjIg/jUeFJRDRXPGafJxHRXGJ4EhGJ4DEHjObCQr7kXUNDA371q19N2/ePf/wDzzzzjPD7119/jffffx/fffcdAgICkJKSgpycHCxZsmSulmuhu7sbpaWl+Pbbb/Hdd99hZGQEH3/8MeLj463GlpWV4eTJk7h16xaCg4OxZcsWvPXWW/D19bUY19fXh4KCAnz11VcwGo3QaDTIycnBc88951Y1JScno6ury+r2v/nNb5CTk2PRNl811dfXQ6vV4r///S+6u7shk8mwZs0avPXWW8LHoyfZ+9xaCK81hqcDPOGSdzk5OYiNjbVoCwsLE35uaGjAzp07sWHDBrz99tvo7e1FYWEhdDodzp49a3XtgLlw+/ZtVFRUQKPRICEhwep0tElarRbvvvsuMjIysH//frS3t6OwsBBdXV0oKioSxhmNRmRmZmJkZAS5ubmQy+X46KOPkJmZiU8//RQajcZtagKA2NhYq6AMCbH8wr75rOmvf/0rDAYDMjMz8cwzz6Cvrw+lpaV47bXXcPr0aURHRwNw7Lm1IF5rZrLLV199ZVapVOYLFy4IbSaTyfyLX/zCvHnz5nlcmX0uX75sVqlU5osXL8447mc/+5l569at5omJCaHt3//+t1mlUpkrKipcvcxpTV3LxYsXzSqVynz58mWLMePj4+Z169aZf/vb31q0/+1vfzOrVCpzY2Oj0HbmzBmzSqUy37x5U2gzGo3m5ORk844dO1xUhSV7ajKbzeaXXnrJvHv37lnnm8+a+vr6rNoGBwfNMTEx5qysLKHN3ufWQnmtcZ+nnWa65F1HRwfa2trmcXXO0dPTgxs3bmDr1q0WWwHr1q1DSEgIqqqq5mVd9mztNjY2Qq/XIz093aI9LS0NPj4+Fmuvrq6GSqXCqlWrhLZFixYhNTUVdXV1GB4edt7ibXD2Fvx81rR06VKrtiVLluDpp59Gd3c3AMeeWwvltcbwtJM9l7xbCPLy8qDRaPD8889j165duHnzptA3WUNERITV7VQqFVpbW+dsnY6aXNuja1+8eDHCw8Mt1t7a2jrtx3XVajUmJibQ0dHh2sU66PLly1i7di0iIyORlpaGs2fPwvzIGYbuVlN/fz9aW1uFx8OR59ZCea1xn6ednH3Ju7kWFBSE7du3Iy4uDnK5HO3t7SgpKUFGRgbOnDmDqKgooYbpLqIik8nQ1NQ018u222xrn/r4GAwGm+MAYGBgwEWrdNyLL76IyMhIhIeHw2Aw4Pz58zh06BC+//577N+/XxjnTjWZzWbk5ubCZDJhx44dwvqmrufRNU59bi2U1xrD0wEL+ZJ3Go3G4qBBTEwMkpOTkZqaiqKiInz44YdCn61a3L1GwP61L5THMi8vz+L3lJQUZGdn4/Tp09i+fTtCQ0OFPnepKT8/H9XV1fjzn/9scRbHTOtYiI8P37bbyRMveadQKLB+/Xpcv34dwMMagen/sg8ODrp1jY6sfbbHcnIud5Weng6TyYRvvvlGaHOXmoqKinDq1CkcOHAAr776qsX6AOc8Pu7yPGR42slTL3k3tZ7J/VHT7dvU6XTT7q9yF5NXzXp07aOjo+js7LRYu1KpnHa/WUtLC7y8vLBy5UrXLvYxTT5mU/cJukNN77//Pk6cOIG9e/danVPsyHNrobzWGJ528sRL3un1etTV1Qnn4S1btgyRkZEoKyuzeOLW19ejp6cHmzZtmq+lzio6OhoKhQJardaivby8HGNjYxZrT0lJgU6nQ3Nzs9D24MEDVFRUIDExEYGB7v21EFqtFlKpFKtXrxba5rum4uJiHD9+HL///e/x61//2qrfkefWQnmteR08ePDgfC9iIXj66adx5coV/P3vf0dwcDCGhoZQXFyMf/7zn/jTn/6EFStWzPcSZ5SdnY3m5mbcvXsXfX19+Ne//oU//OEPuHv3LgoKCoSTrn/yk5/g1KlTaGtrg0wmw9WrV3Ho0CFE
RERg375983KSPABUVlaira0N169fx7Vr1xAWFob+/n50dXVh+fLlkEqlCA4ORklJCQYGBuDn54dLly4hPz8fycnJeOONN4S51Go1Lly4gLKyMigUCvT29uLIkSNoaWlBYWEhnnrqKbeoqby8HH/5y19w//59GAwGfPvtt8Knbt58801s3rzZLWo6deoUjh49ipdeegnp6eno7u4W/vX390OhePjtsfY+txbKa41XVXLAQr7kXUlJCSoqKtDV1YXR0VHI5XLExcVh9+7dVm+DLl26hA8++ED4CN3GjRuxd+/eed3X9OjH/CaFhoZabKFotVqUlpYKH89MS0vDnj174OfnZ3E7vV6P/Px81NbWCh9lzM7ORkxMjEvrmGq2mhobG3Hs2DG0tbXBYDDAx8cHarUa27ZtszqfFZi/mn75y1/iP//5z4y1TLL3ubUQXmsMTyIiEbjPk4hIBIYnEZEIDE8iIhEYnkREIjA8iYhEYHgSEYnA8CQiEoHhSUQkAsOTiEiE/wc0CczVU4cCPAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "task_name = 'HIV'\n",
    "tasks = ['HIV_active']\n",
    "raw_filename = \"../data/HIV.csv\"\n",
    "feature_filename = raw_filename.replace('.csv','.pickle')\n",
    "filename = raw_filename.replace('.csv','')\n",
    "prefix_filename = raw_filename.split('/')[-1].replace('.csv','')\n",
    "smiles_tasks_df = pd.read_csv(raw_filename)\n",
    "smilesList = smiles_tasks_df.smiles.values\n",
    "print(\"number of all smiles: \",len(smilesList))\n",
    "atom_num_dist = []\n",
    "remained_smiles = []\n",
    "canonical_smiles_list = []\n",
    "for smiles in smilesList:\n",
    "    try:        \n",
    "        mol = Chem.MolFromSmiles(smiles)\n",
    "        atom_num_dist.append(len(mol.GetAtoms()))\n",
    "        canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))\n",
    "        remained_smiles.append(smiles)\n",
    "    except:\n",
    "        print(\"not successfully processed smiles: \", smiles)\n",
    "        pass\n",
    "print(\"number of successfully processed smiles: \", len(remained_smiles))\n",
    "smiles_tasks_df = smiles_tasks_df[smiles_tasks_df[\"smiles\"].isin(remained_smiles)]\n",
    "smiles_tasks_df['cano_smiles'] =canonical_smiles_list\n",
    "\n",
    "plt.figure(figsize=(5, 3))\n",
    "sns.set(font_scale=1.5)\n",
    "ax = sns.distplot(atom_num_dist, bins=28, kde=False)\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"atom_num_dist_\"+prefix_filename+\".png\",dpi=200)\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# print(len([i for i in atom_num_dist if i<51]),len([i for i in atom_num_dist if i>50]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_seed = 8\n",
    "start_time = str(time.ctime()).replace(':','-').replace(' ','_')\n",
    "start = time.time()\n",
    "\n",
    "batch_size = 200\n",
    "epochs = 800\n",
    "p_dropout = 0.1\n",
    "fingerprint_dim = 150\n",
    "\n",
    "radius = 4\n",
    "T = 2\n",
    "weight_decay = 3.9 # also known as l2_regularization_lambda\n",
    "learning_rate = 3\n",
    "per_task_output_units_num = 2 # for classification model with 2 classes\n",
    "output_units_num = len(tasks) * per_task_output_units_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>activity</th>\n",
       "      <th>HIV_active</th>\n",
       "      <th>cano_smiles</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>71</th>\n",
       "      <td>C1CN[Co-4]23(N1)(NCCN2)NCCN3</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>C1CN[Co-4]23(N1)(NCCN2)NCCN3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>O=C1O[Cu-5]2(O)(O)(OC1=O)OC(=O)C(=O)O2</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>O=C1O[Cu-5]2(O)(O)(OC1=O)OC(=O)C(=O)O2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>88</th>\n",
       "      <td>CCc1cc[n+]([Mn](SC#N)(SC#N)([n+]2ccc(CC)cc2)([...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCc1cc[n+]([Mn](SC#N)(SC#N)([n+]2ccc(CC)cc2)([...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>O=C1O[Al]23(OC1=O)(OC(=O)C(=O)O2)OC(=O)C(=O)O3</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>O=C1O[Al]23(OC1=O)(OC(=O)C(=O)O2)OC(=O)C(=O)O3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>138</th>\n",
       "      <td>O=C1C[N+]23CC[N+]45CC(=O)O[Ni-4]24(O1)(OC(=O)C...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>O=C1C[N+]23CC[N+]45CC(=O)O[Ni-4]24(O1)(OC(=O)C...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>248</th>\n",
       "      <td>CC1=[O+][Zr]234([O+]=C(C)C1)([O+]=C(C)CC(C)=[O...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1=[O+][Zr]234([O+]=C(C)C1)([O+]=C(C)CC(C)=[O...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>358</th>\n",
       "      <td>O=C1C[N+]23CC[N+]45CC(=O)O[Cu-5]24(O1)(OC(=O)C...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>O=C1C[N+]23CC[N+]45CC(=O)O[Cu-5]24(O1)(OC(=O)C...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>676</th>\n",
       "      <td>c1ccc2c3c(ccc2c1)O[Fe-4]12(Oc4ccc5ccccc5c4N=[O...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>c1ccc2c3c(ccc2c1)O[Fe-4]12(Oc4ccc5ccccc5c4N=[O...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>823</th>\n",
       "      <td>C[N+]1(C)COC(=S)S[Fe-4]123(SC(=S)OC[N+]2(C)C)S...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>C[N+]1(C)COC(=S)S[Fe-4]123(SC(=S)OC[N+]2(C)C)S...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1561</th>\n",
       "      <td>NCC1OC(OC2C(CO)OC(OC3C(O)C(N)CC(N)C3OC3OC(CN)C...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>NCC1OC(OC2C(CO)OC(OC3C(O)C(N)CC(N)C3OC3OC(CN)C...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1683</th>\n",
       "      <td>c1c[n+]([Ni-4]([n+]2cc[nH]c2)([n+]2cc[nH]c2)([...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>c1c[n+]([Ni-4]([n+]2cc[nH]c2)([n+]2cc[nH]c2)([...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2493</th>\n",
       "      <td>Cl[Pd-4]12([S+]=c3nc[nH]c4[nH]cnc34)([S+]=C3N=...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Cl[Pd-4]12([S+]=c3nc[nH]c4[nH]cnc34)([S+]=C3N=...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3046</th>\n",
       "      <td>CC1=[O+][Mn]23([O+]=C(C)C1)([O+]=C(C)CC(C)=[O+...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1=[O+][Mn]23([O+]=C(C)C1)([O+]=C(C)CC(C)=[O+...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3503</th>\n",
       "      <td>O=C(Nc1cc(C(=O)Nc2cc(C(=O)Nc3cc(C(=O)Nc4ccc(C5...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>O=C(Nc1cc(C(=O)Nc2cc(C(=O)Nc3cc(C(=O)Nc4ccc(C5...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3561</th>\n",
       "      <td>O=C1C[N+]23CCO[Fe-4]245(O1)OC(=O)C[N+]4(CC3)CC...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>O=C1C[N+]23CCO[Fe-4]245(O1)OC(=O)C[N+]4(CC3)CC...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4161</th>\n",
       "      <td>NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4162</th>\n",
       "      <td>NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4703</th>\n",
       "      <td>CC(C)(C)OC(=O)CNC(=O)C(Cc1ccccc1)NC(=O)C(CSC(c...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC(C)(C)OC(=O)CNC(=O)C(Cc1ccccc1)NC(=O)C(CSC(c...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5191</th>\n",
       "      <td>Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5599</th>\n",
       "      <td>COC1C=COC2(C)Oc3c(C)c(O)c4c(O)c(c(C=NNC(=O)CC(...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>COC1C=COC2(C)Oc3c(C)c(O)c4c(O)c(c(C=NNC(=O)CC(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5927</th>\n",
       "      <td>Cl[Sn](Cl)(C12C3=C4C5=C1[Fe]45321678C2=C1C6C7=...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Cl[Sn](Cl)(C12C3=C4C5=C1[Fe]45321678C2=C1C6C7=...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5999</th>\n",
       "      <td>Br[Ni-4]12(Br)(NCCN1)NCCN2</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Br[Ni-4]12(Br)(NCCN1)NCCN2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6001</th>\n",
       "      <td>C1CN[Ni-4]23(N1)(NCCN2)NCCN3.[O-][Cl+3]([O-])(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>C1CN[Ni-4]23(N1)(NCCN2)NCCN3.[O-][Cl+3]([O-])(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6669</th>\n",
       "      <td>Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CC2C.F[P-...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CC2C.F[P-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6670</th>\n",
       "      <td>CC1CC[P+](c2ccccc2)(c2ccccc2)Cc2ccccc21.F[P-](...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1CC[P+](c2ccccc2)(c2ccccc2)Cc2ccccc21.F[P-](...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6671</th>\n",
       "      <td>Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CCC2C.F[P...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CCC2C.F[P...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6672</th>\n",
       "      <td>CC1CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-](F...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-](F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6673</th>\n",
       "      <td>CC1CC[P+](c2ccccc2)(c2ccccc2)c2c1ccc1ccccc21.F...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1CC[P+](c2ccccc2)(c2ccccc2)c2c1ccc1ccccc21.F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6676</th>\n",
       "      <td>CC1(C)CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1(C)CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7018</th>\n",
       "      <td>Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36539</th>\n",
       "      <td>C1=NN=C2N3CN4CCN5CN6CN7CCN(C3)[Cu-5]457([NH+]1...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>C1=NN=C2N3CN4CCN5CN6CN7CCN(C3)[Cu-5]457([NH+]1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36814</th>\n",
       "      <td>CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)[OH+]2)[OH+]...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)[OH+]2)[OH+]...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36815</th>\n",
       "      <td>CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)C(=O)[OH+]2)...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)C(=O)[OH+]2)...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36816</th>\n",
       "      <td>CC1=[O+][Co-4]23(NCCN2CCCNC(=O)c2cccc4cc5ccccc...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1=[O+][Co-4]23(NCCN2CCCNC(=O)c2cccc4cc5ccccc...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37443</th>\n",
       "      <td>CC1=[N+]2N=C(c3ccncc3)[OH+][Cu-5]234(O)[OH+]C(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1=[N+]2N=C(c3ccncc3)[OH+][Cu-5]234(O)[OH+]C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37444</th>\n",
       "      <td>CC1=[N+]2N=C(c3ccccc3)[OH+][Cu-5]234(O)[OH+]C(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1=[N+]2N=C(c3ccccc3)[OH+][Cu-5]234(O)[OH+]C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37445</th>\n",
       "      <td>CC1=[N+]2N=C(c3ccccc3)[OH+][Ni-4]234(O)[OH+]C(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC1=[N+]2N=C(c3ccccc3)[OH+][Ni-4]234(O)[OH+]C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37484</th>\n",
       "      <td>CCCN1CCN[Cr]123([O+]=C(C)[CH-]C(C)=[O+]2)[O+]=...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCCN1CCN[Cr]123([O+]=C(C)[CH-]C(C)=[O+]2)[O+]=...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38247</th>\n",
       "      <td>CC(C)[N+]12CC[N+]3(C(C)C)CC[N+](C(C)C)(CC1)[Mo...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC(C)[N+]12CC[N+]3(C(C)C)CC[N+](C(C)C)(CC1)[Mo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38268</th>\n",
       "      <td>COC(c1ccc(O)cc1)C1NC(=O)C(CCC(N)=O)N(C)C(=O)C(...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>COC(c1ccc(O)cc1)C1NC(=O)C(CCC(N)=O)N(C)C(=O)C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38398</th>\n",
       "      <td>CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38399</th>\n",
       "      <td>CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38400</th>\n",
       "      <td>CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38401</th>\n",
       "      <td>CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38402</th>\n",
       "      <td>CCC(C)C(NC(=O)C(N)CCCNC(=N)N)C(=O)NC(CCC(N)=O)...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC(C)C(NC(=O)C(N)CCCNC(=N)N)C(=O)NC(CCC(N)=O)...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38404</th>\n",
       "      <td>CCC(C)C(NC(=O)C(CCCC(N)=O)NC(=O)CNC(=O)C(NC(=O...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC(C)C(NC(=O)C(CCCC(N)=O)NC(=O)CNC(=O)C(NC(=O...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39028</th>\n",
       "      <td>O=C1C[N+]23CC(=O)[OH+][Zr]2456([OH+]1)([OH+]C(...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>O=C1C[N+]23CC(=O)[OH+][Zr]2456([OH+]1)([OH+]C(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39029</th>\n",
       "      <td>[Cl-].c1cc2ccc3ccc[n+]4c3c2[n+](c1)[Co-4]412([...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>[Cl-].c1cc2ccc3ccc[n+]4c3c2[n+](c1)[Co-4]412([...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39030</th>\n",
       "      <td>C1CN[Co-4]23(N1)(NCCN2)NCCN3.O=C([O-])C(O)C(O)...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>C1CN[Co-4]23(N1)(NCCN2)NCCN3.O=C([O-])C(O)C(O)...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39121</th>\n",
       "      <td>COc1cc2cc(C(=O)N3CC(CCl)c4ccc(N5CCN[Cr]567([O+...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>COc1cc2cc(C(=O)N3CC(CCl)c4ccc(N5CCN[Cr]567([O+...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40156</th>\n",
       "      <td>CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...</td>\n",
       "      <td>CA</td>\n",
       "      <td>1</td>\n",
       "      <td>CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40157</th>\n",
       "      <td>CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...</td>\n",
       "      <td>CA</td>\n",
       "      <td>1</td>\n",
       "      <td>CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40158</th>\n",
       "      <td>CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40741</th>\n",
       "      <td>COC(=O)Cc1ccc(COC(=O)c2cc(C(=CCCC3CCC4(C)C(CCC...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>COC(=O)Cc1ccc(COC(=O)c2cc(C(=CCCC3CCC4(C)C(CCC...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40743</th>\n",
       "      <td>CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40746</th>\n",
       "      <td>CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40975</th>\n",
       "      <td>CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41000</th>\n",
       "      <td>CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41027</th>\n",
       "      <td>CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...</td>\n",
       "      <td>CI</td>\n",
       "      <td>0</td>\n",
       "      <td>CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41030</th>\n",
       "      <td>CCCCCCCCCCCCCCCC(=O)Nc1ccn(C2CCC(COP(=O)(O)OCC...</td>\n",
       "      <td>CM</td>\n",
       "      <td>1</td>\n",
       "      <td>CCCCCCCCCCCCCCCC(=O)Nc1ccn(C2CCC(COP(=O)(O)OCC...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>379 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  smiles activity  HIV_active  \\\n",
       "71                          C1CN[Co-4]23(N1)(NCCN2)NCCN3       CI           0   \n",
       "79                O=C1O[Cu-5]2(O)(O)(OC1=O)OC(=O)C(=O)O2       CI           0   \n",
       "88     CCc1cc[n+]([Mn](SC#N)(SC#N)([n+]2ccc(CC)cc2)([...       CI           0   \n",
       "137       O=C1O[Al]23(OC1=O)(OC(=O)C(=O)O2)OC(=O)C(=O)O3       CI           0   \n",
       "138    O=C1C[N+]23CC[N+]45CC(=O)O[Ni-4]24(O1)(OC(=O)C...       CI           0   \n",
       "248    CC1=[O+][Zr]234([O+]=C(C)C1)([O+]=C(C)CC(C)=[O...       CI           0   \n",
       "358    O=C1C[N+]23CC[N+]45CC(=O)O[Cu-5]24(O1)(OC(=O)C...       CI           0   \n",
       "676    c1ccc2c3c(ccc2c1)O[Fe-4]12(Oc4ccc5ccccc5c4N=[O...       CM           1   \n",
       "823    C[N+]1(C)COC(=S)S[Fe-4]123(SC(=S)OC[N+]2(C)C)S...       CI           0   \n",
       "1561   NCC1OC(OC2C(CO)OC(OC3C(O)C(N)CC(N)C3OC3OC(CN)C...       CI           0   \n",
       "1683   c1c[n+]([Ni-4]([n+]2cc[nH]c2)([n+]2cc[nH]c2)([...       CI           0   \n",
       "2493   Cl[Pd-4]12([S+]=c3nc[nH]c4[nH]cnc34)([S+]=C3N=...       CI           0   \n",
       "3046   CC1=[O+][Mn]23([O+]=C(C)C1)([O+]=C(C)CC(C)=[O+...       CI           0   \n",
       "3503   O=C(Nc1cc(C(=O)Nc2cc(C(=O)Nc3cc(C(=O)Nc4ccc(C5...       CM           1   \n",
       "3561   O=C1C[N+]23CCO[Fe-4]245(O1)OC(=O)C[N+]4(CC3)CC...       CI           0   \n",
       "4161   NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...       CI           0   \n",
       "4162   NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...       CI           0   \n",
       "4703   CC(C)(C)OC(=O)CNC(=O)C(Cc1ccccc1)NC(=O)C(CSC(c...       CI           0   \n",
       "5191   Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...       CI           0   \n",
       "5599   COC1C=COC2(C)Oc3c(C)c(O)c4c(O)c(c(C=NNC(=O)CC(...       CM           1   \n",
       "5927   Cl[Sn](Cl)(C12C3=C4C5=C1[Fe]45321678C2=C1C6C7=...       CI           0   \n",
       "5999                          Br[Ni-4]12(Br)(NCCN1)NCCN2       CI           0   \n",
       "6001   C1CN[Ni-4]23(N1)(NCCN2)NCCN3.[O-][Cl+3]([O-])(...       CI           0   \n",
       "6669   Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CC2C.F[P-...       CI           0   \n",
       "6670   CC1CC[P+](c2ccccc2)(c2ccccc2)Cc2ccccc21.F[P-](...       CI           0   \n",
       "6671   Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CCC2C.F[P...       CI           0   \n",
       "6672   CC1CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-](F...       CI           0   \n",
       "6673   CC1CC[P+](c2ccccc2)(c2ccccc2)c2c1ccc1ccccc21.F...       CI           0   \n",
       "6676   CC1(C)CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-...       CI           0   \n",
       "7018   Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...       CI           0   \n",
       "...                                                  ...      ...         ...   \n",
       "36539  C1=NN=C2N3CN4CCN5CN6CN7CCN(C3)[Cu-5]457([NH+]1...       CI           0   \n",
       "36814  CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)[OH+]2)[OH+]...       CI           0   \n",
       "36815  CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)C(=O)[OH+]2)...       CI           0   \n",
       "36816  CC1=[O+][Co-4]23(NCCN2CCCNC(=O)c2cccc4cc5ccccc...       CI           0   \n",
       "37443  CC1=[N+]2N=C(c3ccncc3)[OH+][Cu-5]234(O)[OH+]C(...       CI           0   \n",
       "37444  CC1=[N+]2N=C(c3ccccc3)[OH+][Cu-5]234(O)[OH+]C(...       CI           0   \n",
       "37445  CC1=[N+]2N=C(c3ccccc3)[OH+][Ni-4]234(O)[OH+]C(...       CI           0   \n",
       "37484  CCCN1CCN[Cr]123([O+]=C(C)[CH-]C(C)=[O+]2)[O+]=...       CI           0   \n",
       "38247  CC(C)[N+]12CC[N+]3(C(C)C)CC[N+](C(C)C)(CC1)[Mo...       CI           0   \n",
       "38268  COC(c1ccc(O)cc1)C1NC(=O)C(CCC(N)=O)N(C)C(=O)C(...       CM           1   \n",
       "38398  CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...       CI           0   \n",
       "38399  CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...       CI           0   \n",
       "38400  CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...       CI           0   \n",
       "38401  CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...       CI           0   \n",
       "38402  CCC(C)C(NC(=O)C(N)CCCNC(=N)N)C(=O)NC(CCC(N)=O)...       CI           0   \n",
       "38404  CCC(C)C(NC(=O)C(CCCC(N)=O)NC(=O)CNC(=O)C(NC(=O...       CI           0   \n",
       "39028  O=C1C[N+]23CC(=O)[OH+][Zr]2456([OH+]1)([OH+]C(...       CI           0   \n",
       "39029  [Cl-].c1cc2ccc3ccc[n+]4c3c2[n+](c1)[Co-4]412([...       CI           0   \n",
       "39030  C1CN[Co-4]23(N1)(NCCN2)NCCN3.O=C([O-])C(O)C(O)...       CI           0   \n",
       "39121  COc1cc2cc(C(=O)N3CC(CCl)c4ccc(N5CCN[Cr]567([O+...       CI           0   \n",
       "40156  CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...       CA           1   \n",
       "40157  CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...       CA           1   \n",
       "40158  CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...       CI           0   \n",
       "40741  COC(=O)Cc1ccc(COC(=O)c2cc(C(=CCCC3CCC4(C)C(CCC...       CI           0   \n",
       "40743  CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...       CM           1   \n",
       "40746  CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...       CM           1   \n",
       "40975  CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...       CI           0   \n",
       "41000  CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...       CI           0   \n",
       "41027  CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...       CI           0   \n",
       "41030  CCCCCCCCCCCCCCCC(=O)Nc1ccn(C2CCC(COP(=O)(O)OCC...       CM           1   \n",
       "\n",
       "                                             cano_smiles  \n",
       "71                          C1CN[Co-4]23(N1)(NCCN2)NCCN3  \n",
       "79                O=C1O[Cu-5]2(O)(O)(OC1=O)OC(=O)C(=O)O2  \n",
       "88     CCc1cc[n+]([Mn](SC#N)(SC#N)([n+]2ccc(CC)cc2)([...  \n",
       "137       O=C1O[Al]23(OC1=O)(OC(=O)C(=O)O2)OC(=O)C(=O)O3  \n",
       "138    O=C1C[N+]23CC[N+]45CC(=O)O[Ni-4]24(O1)(OC(=O)C...  \n",
       "248    CC1=[O+][Zr]234([O+]=C(C)C1)([O+]=C(C)CC(C)=[O...  \n",
       "358    O=C1C[N+]23CC[N+]45CC(=O)O[Cu-5]24(O1)(OC(=O)C...  \n",
       "676    c1ccc2c3c(ccc2c1)O[Fe-4]12(Oc4ccc5ccccc5c4N=[O...  \n",
       "823    C[N+]1(C)COC(=S)S[Fe-4]123(SC(=S)OC[N+]2(C)C)S...  \n",
       "1561   NCC1OC(OC2C(CO)OC(OC3C(O)C(N)CC(N)C3OC3OC(CN)C...  \n",
       "1683   c1c[n+]([Ni-4]([n+]2cc[nH]c2)([n+]2cc[nH]c2)([...  \n",
       "2493   Cl[Pd-4]12([S+]=c3nc[nH]c4[nH]cnc34)([S+]=C3N=...  \n",
       "3046   CC1=[O+][Mn]23([O+]=C(C)C1)([O+]=C(C)CC(C)=[O+...  \n",
       "3503   O=C(Nc1cc(C(=O)Nc2cc(C(=O)Nc3cc(C(=O)Nc4ccc(C5...  \n",
       "3561   O=C1C[N+]23CCO[Fe-4]245(O1)OC(=O)C[N+]4(CC3)CC...  \n",
       "4161   NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...  \n",
       "4162   NCCCCC(=O)NC(CCCCNC(=O)OCc1ccccc1)C(=O)NCCCCC(...  \n",
       "4703   CC(C)(C)OC(=O)CNC(=O)C(Cc1ccccc1)NC(=O)C(CSC(c...  \n",
       "5191   Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...  \n",
       "5599   COC1C=COC2(C)Oc3c(C)c(O)c4c(O)c(c(C=NNC(=O)CC(...  \n",
       "5927   Cl[Sn](Cl)(C12C3=C4C5=C1[Fe]45321678C2=C1C6C7=...  \n",
       "5999                          Br[Ni-4]12(Br)(NCCN1)NCCN2  \n",
       "6001   C1CN[Ni-4]23(N1)(NCCN2)NCCN3.[O-][Cl+3]([O-])(...  \n",
       "6669   Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CC2C.F[P-...  \n",
       "6670   CC1CC[P+](c2ccccc2)(c2ccccc2)Cc2ccccc21.F[P-](...  \n",
       "6671   Cc1ccc2c(c1)C[P+](c1ccccc1)(c1ccccc1)CCC2C.F[P...  \n",
       "6672   CC1CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-](F...  \n",
       "6673   CC1CC[P+](c2ccccc2)(c2ccccc2)c2c1ccc1ccccc21.F...  \n",
       "6676   CC1(C)CC[P+](c2ccccc2)(c2ccccc2)c2ccccc21.F[P-...  \n",
       "7018   Cc1c(N)nc(C(CC(N)=O)NCC(N)C(N)=O)nc1C(=O)NC(C(...  \n",
       "...                                                  ...  \n",
       "36539  C1=NN=C2N3CN4CCN5CN6CN7CCN(C3)[Cu-5]457([NH+]1...  \n",
       "36814  CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)[OH+]2)[OH+]...  \n",
       "36815  CC[N+]1(CC)CCN[Co-4]123([OH+]C(=O)C(=O)[OH+]2)...  \n",
       "36816  CC1=[O+][Co-4]23(NCCN2CCCNC(=O)c2cccc4cc5ccccc...  \n",
       "37443  CC1=[N+]2N=C(c3ccncc3)[OH+][Cu-5]234(O)[OH+]C(...  \n",
       "37444  CC1=[N+]2N=C(c3ccccc3)[OH+][Cu-5]234(O)[OH+]C(...  \n",
       "37445  CC1=[N+]2N=C(c3ccccc3)[OH+][Ni-4]234(O)[OH+]C(...  \n",
       "37484  CCCN1CCN[Cr]123([O+]=C(C)[CH-]C(C)=[O+]2)[O+]=...  \n",
       "38247  CC(C)[N+]12CC[N+]3(C(C)C)CC[N+](C(C)C)(CC1)[Mo...  \n",
       "38268  COC(c1ccc(O)cc1)C1NC(=O)C(CCC(N)=O)N(C)C(=O)C(...  \n",
       "38398  CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...  \n",
       "38399  CCC(C)C(NC(=O)CNC(=O)C(Cc1c[nH]cn1)NC(=O)C(NC(...  \n",
       "38400  CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...  \n",
       "38401  CCC(C)C(NC(=O)C(CCCCN)NC(=O)CNC(=O)C(NC(=O)C(N...  \n",
       "38402  CCC(C)C(NC(=O)C(N)CCCNC(=N)N)C(=O)NC(CCC(N)=O)...  \n",
       "38404  CCC(C)C(NC(=O)C(CCCC(N)=O)NC(=O)CNC(=O)C(NC(=O...  \n",
       "39028  O=C1C[N+]23CC(=O)[OH+][Zr]2456([OH+]1)([OH+]C(...  \n",
       "39029  [Cl-].c1cc2ccc3ccc[n+]4c3c2[n+](c1)[Co-4]412([...  \n",
       "39030  C1CN[Co-4]23(N1)(NCCN2)NCCN3.O=C([O-])C(O)C(O)...  \n",
       "39121  COc1cc2cc(C(=O)N3CC(CCl)c4ccc(N5CCN[Cr]567([O+...  \n",
       "40156  CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...  \n",
       "40157  CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...  \n",
       "40158  CCC1NC(=O)C(C(O)C(C)C)N(C)C(=O)C(C(C)C)N(C)C(=...  \n",
       "40741  COC(=O)Cc1ccc(COC(=O)c2cc(C(=CCCC3CCC4(C)C(CCC...  \n",
       "40743  CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...  \n",
       "40746  CC1OC(OC2C(O)COC(OC3C(C)OC(OC4C(OC(=O)C56CCC(C...  \n",
       "40975  CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...  \n",
       "41000  CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...  \n",
       "41027  CC(C)CCCC(C)C1CCC2C3CCC4CC(CCC=C(c5cc(Cl)c(OCc...  \n",
       "41030  CCCCCCCCCCCCCCCC(=O)Nc1ccn(C2CCC(COP(=O)(O)OCC...  \n",
       "\n",
       "[379 rows x 4 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep molecules with at most MAX_ATOMS atoms. SMILES that RDKit cannot parse\n",
    "# (MolFromSmiles returns None) are skipped instead of crashing on None.GetNumAtoms().\n",
    "MAX_ATOMS = 100\n",
    "smilesList = [smiles for smiles, mol in ((s, Chem.MolFromSmiles(s)) for s in canonical_smiles_list)\n",
    "              if mol is not None and mol.GetNumAtoms() <= MAX_ATOMS]\n",
    "\n",
    "# Featurization is expensive, so reuse the cached pickle when it exists.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load cache files you created yourself.\n",
    "if os.path.isfile(feature_filename):\n",
    "    with open(feature_filename, \"rb\") as f:\n",
    "        feature_dicts = pickle.load(f)\n",
    "else:\n",
    "    feature_dicts = save_smiles_dicts(smilesList, filename)\n",
    "\n",
    "# Rows whose cano_smiles has no entry in feature_dicts are excluded from training;\n",
    "# display them for inspection.\n",
    "remained_df = smiles_tasks_df[smiles_tasks_df[\"cano_smiles\"].isin(feature_dicts['smiles_to_atom_mask'].keys())]\n",
    "uncovered_df = smiles_tasks_df.drop(remained_df.index)\n",
    "uncovered_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task class weights for nn.CrossEntropyLoss: [labelled/negatives, labelled/positives]\n",
    "# for classes 0 and 1, so the rarer class receives the larger weight.\n",
    "weights = []\n",
    "for task in tasks:\n",
    "    labels = remained_df[task]\n",
    "    n_negative = int((labels == 0).sum())\n",
    "    n_positive = int((labels == 1).sum())\n",
    "    n_labelled = n_positive + n_negative\n",
    "    weights.append([n_labelled / n_negative, n_labelled / n_positive])\n",
    "\n",
    "# Roughly 80/10/10 split: 1/10 of the rows go to test, then 1/9 of the remainder to validation.\n",
    "test_df = remained_df.sample(frac=1/10, random_state=random_seed)\n",
    "training_data = remained_df.drop(test_df.index)\n",
    "valid_df = training_data.sample(frac=1/9, random_state=random_seed)\n",
    "train_df = training_data.drop(valid_df.index).reset_index(drop=True)\n",
    "# train()/eval() slice batches with .loc on positional ids, so every split gets a 0..n-1 index.\n",
    "valid_df = valid_df.reset_index(drop=True)\n",
    "test_df = test_df.reset_index(drop=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "808057\n",
      "atom_fc.weight torch.Size([150, 39])\n",
      "atom_fc.bias torch.Size([150])\n",
      "neighbor_fc.weight torch.Size([150, 49])\n",
      "neighbor_fc.bias torch.Size([150])\n",
      "GRUCell.0.weight_ih torch.Size([450, 150])\n",
      "GRUCell.0.weight_hh torch.Size([450, 150])\n",
      "GRUCell.0.bias_ih torch.Size([450])\n",
      "GRUCell.0.bias_hh torch.Size([450])\n",
      "GRUCell.1.weight_ih torch.Size([450, 150])\n",
      "GRUCell.1.weight_hh torch.Size([450, 150])\n",
      "GRUCell.1.bias_ih torch.Size([450])\n",
      "GRUCell.1.bias_hh torch.Size([450])\n",
      "GRUCell.2.weight_ih torch.Size([450, 150])\n",
      "GRUCell.2.weight_hh torch.Size([450, 150])\n",
      "GRUCell.2.bias_ih torch.Size([450])\n",
      "GRUCell.2.bias_hh torch.Size([450])\n",
      "GRUCell.3.weight_ih torch.Size([450, 150])\n",
      "GRUCell.3.weight_hh torch.Size([450, 150])\n",
      "GRUCell.3.bias_ih torch.Size([450])\n",
      "GRUCell.3.bias_hh torch.Size([450])\n",
      "align.0.weight torch.Size([1, 300])\n",
      "align.0.bias torch.Size([1])\n",
      "align.1.weight torch.Size([1, 300])\n",
      "align.1.bias torch.Size([1])\n",
      "align.2.weight torch.Size([1, 300])\n",
      "align.2.bias torch.Size([1])\n",
      "align.3.weight torch.Size([1, 300])\n",
      "align.3.bias torch.Size([1])\n",
      "attend.0.weight torch.Size([150, 150])\n",
      "attend.0.bias torch.Size([150])\n",
      "attend.1.weight torch.Size([150, 150])\n",
      "attend.1.bias torch.Size([150])\n",
      "attend.2.weight torch.Size([150, 150])\n",
      "attend.2.bias torch.Size([150])\n",
      "attend.3.weight torch.Size([150, 150])\n",
      "attend.3.bias torch.Size([150])\n",
      "mol_GRUCell.weight_ih torch.Size([450, 150])\n",
      "mol_GRUCell.weight_hh torch.Size([450, 150])\n",
      "mol_GRUCell.bias_ih torch.Size([450])\n",
      "mol_GRUCell.bias_hh torch.Size([450])\n",
      "mol_align.weight torch.Size([1, 300])\n",
      "mol_align.bias torch.Size([1])\n",
      "mol_attend.weight torch.Size([150, 150])\n",
      "mol_attend.bias torch.Size([150])\n",
      "output.weight torch.Size([2, 150])\n",
      "output.bias torch.Size([2])\n"
     ]
    }
   ],
   "source": [
    "# Infer feature sizes from one molecule. canonical_smiles_list[0] is not guaranteed to be in\n",
    "# feature_dicts (it may be one of the uncovered_df rows), whereas every remained_df row is.\n",
    "x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([remained_df.cano_smiles.values[0]], feature_dicts)\n",
    "num_atom_features = x_atom.shape[-1]\n",
    "num_bond_features = x_bonds.shape[-1]\n",
    "\n",
    "# One class-weighted loss per task; weights[i] holds the [class 0, class 1] weights.\n",
    "loss_function = [nn.CrossEntropyLoss(torch.Tensor(weight), reduction='mean') for weight in weights]\n",
    "model = Fingerprint(radius, T, num_atom_features, num_bond_features,\n",
    "                    fingerprint_dim, output_units_num, p_dropout)\n",
    "model.cuda()\n",
    "\n",
    "# learning_rate and weight_decay are stored as negative base-10 exponents (e.g. 3 -> 1e-3).\n",
    "optimizer = optim.Adam(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)\n",
    "\n",
    "# Report the trainable parameter count and each trainable tensor's shape.\n",
    "params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)\n",
    "print(params)\n",
    "for name, param in model.named_parameters():\n",
    "    if param.requires_grad:\n",
    "        print(name, param.data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _featurize_and_predict(model, batch_df):\n",
    "    \"\"\"Featurize `batch_df.cano_smiles` with `get_smiles_array` and run `model` on the result.\n",
    "\n",
    "    Returns whatever `model` returns (callers unpack it as atoms_prediction, mol_prediction).\n",
    "    Relies on the notebook-level `feature_dicts`.\n",
    "    \"\"\"\n",
    "    smiles_list = batch_df.cano_smiles.values\n",
    "    x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list, feature_dicts)\n",
    "    return model(torch.Tensor(x_atom), torch.Tensor(x_bonds),\n",
    "                 torch.cuda.LongTensor(x_atom_index), torch.cuda.LongTensor(x_bond_index),\n",
    "                 torch.Tensor(x_mask))\n",
    "\n",
    "\n",
    "def _labelled_task_outputs(mol_prediction, batch_df, i, task):\n",
    "    \"\"\"Pick task `i`'s logits and labels for the rows whose label is exactly 0 or 1.\n",
    "\n",
    "    Returns (y_pred_adjust, y_val_adjust): the selected rows of this task's slice of\n",
    "    `mol_prediction` and their labels as a float numpy array; or None when no row of\n",
    "    the batch has a valid label. Relies on the notebook-level `per_task_output_units_num`.\n",
    "    \"\"\"\n",
    "    y_pred = mol_prediction[:, i * per_task_output_units_num:(i + 1) * per_task_output_units_num]\n",
    "    y_val = batch_df[task].values\n",
    "    valid_inds = np.where((y_val == 0) | (y_val == 1))[0]\n",
    "    if len(valid_inds) == 0:\n",
    "        return None\n",
    "    y_val_adjust = y_val[valid_inds].astype(float)\n",
    "    y_pred_adjust = torch.index_select(y_pred, 0, torch.cuda.LongTensor(valid_inds))\n",
    "    return y_pred_adjust, y_val_adjust\n",
    "\n",
    "\n",
    "def train(model, dataset, optimizer, loss_function, seed=None):\n",
    "    \"\"\"Run one training epoch over `dataset` in shuffled mini-batches of `batch_size` rows.\n",
    "\n",
    "    For each task, `loss_function[i]` is applied to that task's labelled rows and the task\n",
    "    losses are summed before the backward pass. `seed` seeds numpy's shuffle; when omitted,\n",
    "    the notebook-level loop variable `epoch` is used, preserving the original behaviour.\n",
    "    Relies on notebook-level `batch_size` and `tasks`.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    np.random.seed(epoch if seed is None else seed)\n",
    "    order = np.arange(0, dataset.shape[0])\n",
    "    np.random.shuffle(order)\n",
    "    for start in range(0, dataset.shape[0], batch_size):\n",
    "        # .loc with positional ids assumes `dataset` has a 0..n-1 index (the splits are reset_index'ed).\n",
    "        batch_df = dataset.loc[order[start:start + batch_size], :]\n",
    "        _, mol_prediction = _featurize_and_predict(model, batch_df)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = None\n",
    "        for i, task in enumerate(tasks):\n",
    "            selected = _labelled_task_outputs(mol_prediction, batch_df, i, task)\n",
    "            if selected is None:\n",
    "                continue\n",
    "            y_pred_adjust, y_val_adjust = selected\n",
    "            task_loss = loss_function[i](y_pred_adjust, torch.cuda.LongTensor(y_val_adjust))\n",
    "            loss = task_loss if loss is None else loss + task_loss\n",
    "        # A batch with no 0/1 label for any task has nothing to learn from; previously the\n",
    "        # float 0.0 reached .backward() here and raised AttributeError.\n",
    "        if loss is None:\n",
    "            continue\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def eval(model, dataset):\n",
    "    \"\"\"Score `model` on `dataset` without updating it. NOTE: shadows the builtin `eval`.\n",
    "\n",
    "    Returns (eval_roc, eval_loss): a list with one ROC-AUC per task, and the mean of the\n",
    "    per-(batch, task) losses computed with the notebook-level `loss_function`.\n",
    "    Raises KeyError if some task has no 0/1 label anywhere in `dataset`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    y_val_list = {}\n",
    "    y_pred_list = {}\n",
    "    losses_list = []\n",
    "    order = np.arange(0, dataset.shape[0])\n",
    "    # Evaluation never calls backward(), so skip building the autograd graph (saves GPU memory).\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, dataset.shape[0], batch_size):\n",
    "            batch_df = dataset.loc[order[start:start + batch_size], :]\n",
    "            _, mol_prediction = _featurize_and_predict(model, batch_df)\n",
    "            for i, task in enumerate(tasks):\n",
    "                selected = _labelled_task_outputs(mol_prediction, batch_df, i, task)\n",
    "                if selected is None:\n",
    "                    continue\n",
    "                y_pred_adjust, y_val_adjust = selected\n",
    "                loss = loss_function[i](y_pred_adjust, torch.cuda.LongTensor(y_val_adjust))\n",
    "                losses_list.append(loss.cpu().detach().numpy())\n",
    "                # Probability of class 1 (column 1 of the softmax over this task's logits).\n",
    "                positive_prob = F.softmax(y_pred_adjust, dim=-1).data.cpu().numpy()[:, 1]\n",
    "                y_val_list.setdefault(i, []).extend(y_val_adjust)\n",
    "                y_pred_list.setdefault(i, []).extend(positive_prob)\n",
    "\n",
    "    eval_roc = [roc_auc_score(y_val_list[i], y_pred_list[i]) for i in range(len(tasks))]\n",
    "    eval_loss = np.array(losses_list).mean()\n",
    "    return eval_roc, eval_loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH:\t0\n",
      "train_roc:[0.6423051131555713]\n",
      "valid_roc:[0.6131278115549044]\n",
      "\n",
      "EPOCH:\t1\n",
      "train_roc:[0.717566333880532]\n",
      "valid_roc:[0.7171916391560532]\n",
      "\n",
      "EPOCH:\t2\n",
      "train_roc:[0.7391916463403431]\n",
      "valid_roc:[0.7336779046963855]\n",
      "\n",
      "EPOCH:\t3\n",
      "train_roc:[0.7530537846947918]\n",
      "valid_roc:[0.7367379485644597]\n",
      "\n",
      "EPOCH:\t4\n",
      "train_roc:[0.7669964628414634]\n",
      "valid_roc:[0.7376760642028474]\n",
      "\n",
      "EPOCH:\t5\n",
      "train_roc:[0.7695541618620173]\n",
      "valid_roc:[0.750159703019392]\n",
      "\n",
      "EPOCH:\t6\n",
      "train_roc:[0.7695492716565855]\n",
      "valid_roc:[0.759918339267287]\n",
      "\n",
      "EPOCH:\t7\n",
      "train_roc:[0.7849213622545306]\n",
      "valid_roc:[0.7734160364167556]\n",
      "\n",
      "EPOCH:\t8\n",
      "train_roc:[0.7910773608885152]\n",
      "valid_roc:[0.774242471622002]\n",
      "\n",
      "EPOCH:\t9\n",
      "train_roc:[0.7982776480327325]\n",
      "valid_roc:[0.7728397653817459]\n",
      "\n",
      "EPOCH:\t10\n",
      "train_roc:[0.7967904582095472]\n",
      "valid_roc:[0.7736125939790846]\n",
      "\n",
      "EPOCH:\t11\n",
      "train_roc:[0.8022484678256905]\n",
      "valid_roc:[0.7767016747597753]\n",
      "\n",
      "EPOCH:\t12\n",
      "train_roc:[0.8016605354477754]\n",
      "valid_roc:[0.7929668130424877]\n",
      "\n",
      "EPOCH:\t13\n",
      "train_roc:[0.8058380137012208]\n",
      "valid_roc:[0.7711980630145675]\n",
      "\n",
      "EPOCH:\t14\n",
      "train_roc:[0.8163520634505506]\n",
      "valid_roc:[0.7908381839868128]\n",
      "\n",
      "EPOCH:\t15\n",
      "train_roc:[0.8102610964434968]\n",
      "valid_roc:[0.7716514855731218]\n",
      "\n",
      "EPOCH:\t16\n",
      "train_roc:[0.821288712325155]\n",
      "valid_roc:[0.7949569583610672]\n",
      "\n",
      "EPOCH:\t17\n",
      "train_roc:[0.8172597774389156]\n",
      "valid_roc:[0.7902328760391865]\n",
      "\n",
      "EPOCH:\t18\n",
      "train_roc:[0.8222601745992464]\n",
      "valid_roc:[0.7961675742563199]\n",
      "\n",
      "EPOCH:\t19\n",
      "train_roc:[0.8289279697055827]\n",
      "valid_roc:[0.798593273264151]\n",
      "\n",
      "EPOCH:\t20\n",
      "train_roc:[0.8306198186787476]\n",
      "valid_roc:[0.7943181462834985]\n",
      "\n",
      "EPOCH:\t21\n",
      "train_roc:[0.8323310798762373]\n",
      "valid_roc:[0.7914725288470559]\n",
      "\n",
      "EPOCH:\t22\n",
      "train_roc:[0.8290816869641713]\n",
      "valid_roc:[0.7880841445055461]\n",
      "\n",
      "EPOCH:\t23\n",
      "train_roc:[0.8377049158182207]\n",
      "valid_roc:[0.7817630319897433]\n",
      "\n",
      "EPOCH:\t24\n",
      "train_roc:[0.8334745774160046]\n",
      "valid_roc:[0.8056403085953728]\n",
      "\n",
      "EPOCH:\t25\n",
      "train_roc:[0.8415103872286206]\n",
      "valid_roc:[0.7927881243494612]\n",
      "\n",
      "EPOCH:\t26\n",
      "train_roc:[0.8483380326136176]\n",
      "valid_roc:[0.8005789513654049]\n",
      "\n",
      "EPOCH:\t27\n",
      "train_roc:[0.851649795908222]\n",
      "valid_roc:[0.7845438747749638]\n",
      "\n",
      "EPOCH:\t28\n",
      "train_roc:[0.8559138929385435]\n",
      "valid_roc:[0.7947335974947844]\n",
      "\n",
      "EPOCH:\t29\n",
      "train_roc:[0.857453064834938]\n",
      "valid_roc:[0.792482119962654]\n",
      "\n",
      "EPOCH:\t30\n",
      "train_roc:[0.8563731669834942]\n",
      "valid_roc:[0.7875123406878622]\n",
      "\n",
      "EPOCH:\t31\n",
      "train_roc:[0.8632574820143107]\n",
      "valid_roc:[0.7919415866662498]\n",
      "\n",
      "EPOCH:\t32\n",
      "train_roc:[0.866450475457653]\n",
      "valid_roc:[0.7906840649890776]\n",
      "\n",
      "EPOCH:\t33\n",
      "train_roc:[0.867732425250103]\n",
      "valid_roc:[0.7906438600331468]\n",
      "\n",
      "EPOCH:\t34\n",
      "train_roc:[0.870706332086556]\n",
      "valid_roc:[0.7932795182552835]\n",
      "\n",
      "EPOCH:\t35\n",
      "train_roc:[0.8739417973693396]\n",
      "valid_roc:[0.7801570673611701]\n",
      "\n",
      "EPOCH:\t36\n",
      "train_roc:[0.8664285235686288]\n",
      "valid_roc:[0.7879255582904854]\n",
      "\n",
      "EPOCH:\t37\n",
      "train_roc:[0.8790251524068996]\n",
      "valid_roc:[0.7951557495320589]\n",
      "\n",
      "EPOCH:\t38\n",
      "train_roc:[0.8746871619409089]\n",
      "valid_roc:[0.7913720164572287]\n",
      "\n",
      "EPOCH:\t39\n",
      "train_roc:[0.881575516119306]\n",
      "valid_roc:[0.7918589431457251]\n",
      "\n",
      "EPOCH:\t40\n",
      "train_roc:[0.8771064897619091]\n",
      "valid_roc:[0.7948676140145542]\n",
      "\n",
      "EPOCH:\t41\n",
      "train_roc:[0.8843242438553623]\n",
      "valid_roc:[0.7979030881873371]\n",
      "\n",
      "EPOCH:\t42\n",
      "train_roc:[0.8927245901143281]\n",
      "valid_roc:[0.8056403085953728]\n",
      "\n",
      "EPOCH:\t43\n",
      "train_roc:[0.8900884451740426]\n",
      "valid_roc:[0.7758752395545291]\n",
      "\n",
      "EPOCH:\t44\n",
      "train_roc:[0.8940874174203708]\n",
      "valid_roc:[0.7934515061223213]\n",
      "\n",
      "EPOCH:\t45\n",
      "train_roc:[0.8988878024767675]\n",
      "valid_roc:[0.791392118935194]\n",
      "\n",
      "EPOCH:\t46\n",
      "train_roc:[0.8916221189665402]\n",
      "valid_roc:[0.7965137835990584]\n",
      "\n",
      "EPOCH:\t47\n",
      "train_roc:[0.8963365606888005]\n",
      "valid_roc:[0.8004203651503442]\n",
      "\n",
      "EPOCH:\t48\n",
      "train_roc:[0.8999403313884186]\n",
      "valid_roc:[0.7910905817657121]\n",
      "\n",
      "EPOCH:\t49\n",
      "train_roc:[0.9090100414019381]\n",
      "valid_roc:[0.8015617391770492]\n",
      "\n",
      "EPOCH:\t50\n",
      "train_roc:[0.9077852205563703]\n",
      "valid_roc:[0.801577374437689]\n",
      "\n",
      "EPOCH:\t51\n",
      "train_roc:[0.9034628868530745]\n",
      "valid_roc:[0.8031632365882967]\n",
      "\n",
      "EPOCH:\t52\n",
      "train_roc:[0.9137630373957522]\n",
      "valid_roc:[0.7855467650645738]\n",
      "\n",
      "EPOCH:\t53\n",
      "train_roc:[0.911943786413118]\n",
      "valid_roc:[0.7843975734075487]\n",
      "\n",
      "EPOCH:\t54\n",
      "train_roc:[0.9226550924600057]\n",
      "valid_roc:[0.7980214694464671]\n",
      "\n",
      "EPOCH:\t55\n",
      "train_roc:[0.9151515990701586]\n",
      "valid_roc:[0.7843953397988859]\n",
      "\n",
      "EPOCH:\t56\n",
      "train_roc:[0.9266048924749194]\n",
      "valid_roc:[0.7987697283485145]\n",
      "\n",
      "EPOCH:\t57\n",
      "train_roc:[0.9320034901477222]\n",
      "valid_roc:[0.7908917905947206]\n",
      "\n",
      "EPOCH:\t58\n",
      "train_roc:[0.929411316529759]\n",
      "valid_roc:[0.7964467753391736]\n",
      "\n",
      "EPOCH:\t59\n",
      "train_roc:[0.9204188501478948]\n",
      "valid_roc:[0.7948140074066463]\n",
      "\n",
      "EPOCH:\t60\n",
      "train_roc:[0.9325041688325864]\n",
      "valid_roc:[0.7850129325941578]\n",
      "\n",
      "EPOCH:\t61\n",
      "train_roc:[0.9414015869121892]\n",
      "valid_roc:[0.8049434226925705]\n",
      "\n",
      "EPOCH:\t62\n",
      "train_roc:[0.9416257393396223]\n",
      "valid_roc:[0.7992923927756163]\n",
      "\n",
      "EPOCH:\t63\n",
      "train_roc:[0.9310294882089312]\n",
      "valid_roc:[0.8143804192930182]\n",
      "\n",
      "EPOCH:\t64\n",
      "train_roc:[0.9394698747134772]\n",
      "valid_roc:[0.7988412038257249]\n",
      "\n",
      "EPOCH:\t65\n",
      "train_roc:[0.9475296762521899]\n",
      "valid_roc:[0.7916891888873501]\n",
      "\n",
      "EPOCH:\t66\n",
      "train_roc:[0.9523877441185149]\n",
      "valid_roc:[0.811336010685584]\n",
      "\n",
      "EPOCH:\t67\n",
      "train_roc:[0.9506153013177078]\n",
      "valid_roc:[0.8150974076737859]\n",
      "\n",
      "EPOCH:\t68\n",
      "train_roc:[0.9545148078662602]\n",
      "valid_roc:[0.8058614358529927]\n",
      "\n",
      "EPOCH:\t69\n",
      "train_roc:[0.9527876602555442]\n",
      "valid_roc:[0.7968220215945286]\n",
      "\n",
      "EPOCH:\t70\n",
      "train_roc:[0.9510175950131793]\n",
      "valid_roc:[0.8115772404211693]\n",
      "\n",
      "EPOCH:\t71\n",
      "train_roc:[0.9520714342837445]\n",
      "valid_roc:[0.818534931405878]\n",
      "\n",
      "EPOCH:\t72\n",
      "train_roc:[0.9542219088768307]\n",
      "valid_roc:[0.8218987460520967]\n",
      "\n",
      "EPOCH:\t73\n",
      "train_roc:[0.9589680288635595]\n",
      "valid_roc:[0.825334036175526]\n",
      "\n",
      "EPOCH:\t74\n",
      "train_roc:[0.9609700141248585]\n",
      "valid_roc:[0.8240564120203885]\n",
      "\n",
      "EPOCH:\t75\n",
      "train_roc:[0.9634746503638205]\n",
      "valid_roc:[0.8184679231459931]\n",
      "\n",
      "EPOCH:\t76\n",
      "train_roc:[0.9635862335043375]\n",
      "valid_roc:[0.8247331954452253]\n",
      "\n",
      "EPOCH:\t77\n",
      "train_roc:[0.9647926417808346]\n",
      "valid_roc:[0.8240631128463769]\n",
      "\n",
      "EPOCH:\t78\n",
      "train_roc:[0.9688192531440508]\n",
      "valid_roc:[0.808349675903383]\n",
      "\n",
      "EPOCH:\t79\n",
      "train_roc:[0.9605623168874734]\n",
      "valid_roc:[0.8169066306906765]\n",
      "\n",
      "EPOCH:\t80\n",
      "train_roc:[0.9690614669103305]\n",
      "valid_roc:[0.8126538397966524]\n",
      "\n",
      "EPOCH:\t81\n",
      "train_roc:[0.9708341798882332]\n",
      "valid_roc:[0.8198080883436898]\n",
      "\n",
      "EPOCH:\t82\n",
      "train_roc:[0.9654475645696241]\n",
      "valid_roc:[0.8172193359034725]\n",
      "\n",
      "EPOCH:\t83\n",
      "train_roc:[0.9691188930520178]\n",
      "valid_roc:[0.8104336327858014]\n",
      "\n",
      "EPOCH:\t84\n",
      "train_roc:[0.9687796586906785]\n",
      "valid_roc:[0.8125466265808365]\n",
      "\n",
      "EPOCH:\t85\n",
      "train_roc:[0.9726041640774694]\n",
      "valid_roc:[0.82110358136813]\n",
      "\n",
      "EPOCH:\t86\n",
      "train_roc:[0.9721164268754883]\n",
      "valid_roc:[0.8114298222494225]\n",
      "\n",
      "EPOCH:\t87\n",
      "train_roc:[0.9748418113104776]\n",
      "valid_roc:[0.8116263798117515]\n",
      "\n",
      "EPOCH:\t88\n",
      "train_roc:[0.9678358895688947]\n",
      "valid_roc:[0.8148249074169208]\n",
      "\n",
      "EPOCH:\t89\n",
      "train_roc:[0.9753885957167204]\n",
      "valid_roc:[0.821816102531572]\n",
      "\n",
      "EPOCH:\t90\n",
      "train_roc:[0.9755214553035222]\n",
      "valid_roc:[0.8126136348407214]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train with early stopping; checkpoint the model whenever the mean validation ROC improves.\n",
    "os.makedirs('saved_models', exist_ok=True)  # torch.save does not create missing directories\n",
    "\n",
    "# Early-stopping patience (epochs without improvement) for the two monitored criteria.\n",
    "roc_patience = 16\n",
    "loss_patience = 18\n",
    "\n",
    "best_param = {}\n",
    "best_param[\"roc_epoch\"] = 0\n",
    "best_param[\"loss_epoch\"] = 0\n",
    "best_param[\"valid_roc\"] = 0\n",
    "best_param[\"valid_loss\"] = 9e8\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Evaluate before the training pass, so the metrics logged for `epoch`\n",
    "    # describe the model after `epoch` training passes.\n",
    "    train_roc, train_loss = eval(model, train_df)\n",
    "    valid_roc, valid_loss = eval(model, valid_df)\n",
    "    train_roc_mean = np.array(train_roc).mean()\n",
    "    valid_roc_mean = np.array(valid_roc).mean()\n",
    "\n",
    "    if valid_roc_mean > best_param[\"valid_roc\"]:\n",
    "        best_param[\"roc_epoch\"] = epoch\n",
    "        best_param[\"valid_roc\"] = valid_roc_mean\n",
    "        # Save every new best unconditionally: the evaluation cell loads the checkpoint\n",
    "        # for best_param[\"roc_epoch\"], so a conditional save (e.g. only above a fixed\n",
    "        # ROC threshold) could leave it pointing at a file that was never written.\n",
    "        torch.save(model, 'saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(epoch)+'.pt')\n",
    "    if valid_loss < best_param[\"valid_loss\"]:\n",
    "        best_param[\"loss_epoch\"] = epoch\n",
    "        best_param[\"valid_loss\"] = valid_loss\n",
    "\n",
    "    print(\"EPOCH:\\t\" + str(epoch) + \"\\n\"\n",
    "          + \"train_roc:\" + str(train_roc) + \"\\n\"\n",
    "          + \"valid_roc:\" + str(valid_roc) + \"\\n\")\n",
    "\n",
    "    # Stop only when BOTH the validation ROC and the validation loss have stalled.\n",
    "    if (epoch - best_param[\"roc_epoch\"] > roc_patience) and (epoch - best_param[\"loss_epoch\"] > loss_patience):\n",
    "        break\n",
    "\n",
    "    train(model, train_df, optimizer, loss_function)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "best epoch:73\n",
      "test_roc:[0.8483737648847225]\n",
      "test_roc_mean: 0.8483737648847225\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the checkpoint from the epoch with the best mean validation ROC on the test set.\n",
    "# NOTE: torch.load unpickles the whole model object, which can execute arbitrary code;\n",
    "# only load checkpoints this notebook wrote itself.\n",
    "best_model = torch.load('saved_models/model_' + prefix_filename + '_' + start_time + '_'\n",
    "                        + str(best_param[\"roc_epoch\"]) + '.pt')\n",
    "\n",
    "test_roc, test_losses = eval(best_model, test_df)\n",
    "\n",
    "print(\"best epoch:\" + str(best_param[\"roc_epoch\"]))\n",
    "print(\"test_roc:\" + str(test_roc))\n",
    "print(\"test_roc_mean:\", str(np.array(test_roc).mean()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
